Conversation

@zsotakal zsotakal commented Jan 20, 2026

Proposed changes

Add support for grouped GEMM multi ABD with fixed NK. This MR contains:

  • Device struct for grouped GEMM with multiple ABD and fixed NK (DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK).
  • WMMA versions of the existing example code in 59_grouped_gemm_multi_ABD.
  • Unit tests for both the new WMMA implementation and the existing XDL code (previously missing).
    Note: Some XDL instances were commented out because of unit test failures. Since this feature previously had no XDL tests, our assumption is that there is either an implementation bug or these instances were not configured correctly. This is a candidate for a follow-up issue.
  • Generic ckProfiler interface used for calling the unit tests.
  • GEMM instances with specific elementwise operations for GEMM + bias + GELU calculations.
  • Reference implementation class for grouped GEMM multi ABD calculations.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

@afagaj afagaj requested a review from Copilot January 20, 2026 22:09
Copilot AI left a comment

Pull request overview

This PR implements support for grouped GEMM with multiple ABD tensors and fixed NK on RDNA4 architecture, specifically for WMMA implementations. The feature was previously only available for XDL implementations.

Changes:

  • Added WMMA device operator implementations for grouped GEMM multi ABD with fixed NK
  • Unit tests for both new WMMA and existing XDL implementations
  • Reference implementation class for verification
  • Example code demonstrating WMMA usage patterns

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 6 comments.

Summary per file:

  • test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp: Unit test framework for validating grouped GEMM multi ABD fixed NK implementations
  • test/grouped_gemm/CMakeLists.txt: Build configuration for the new unit test
  • profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp: Generic profiler interface for calling unit tests and benchmarking
  • profiler/include/profiler/profile_gemm_multi_abd_impl.hpp: Refactored to use the new reference implementation
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp: Commented out failing XDL instances
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp: Commented out failing XDL instances
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp: WMMA instances for MK-NK-MN layout with bias/gelu operations
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp: WMMA instances for MK-KN-MN layout with bias/gelu operations
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp: WMMA instances for KM-KN-MN layout with bias/gelu operations
  • library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt: Build configuration for the new WMMA instances
  • library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp: Factory functions for WMMA instances
  • library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp: Reference implementation for grouped GEMM multi ABD verification
  • include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp: Added EDataType_ alias for type access
  • include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp: Added hardware support checks and main K block loop validation
  • include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp: New WMMA device operator implementation
  • example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp: Example using WMMA with FP16 and bias addition
  • example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp: Example using WMMA with BF16/I8 and bias+GELU
  • example/59_grouped_gemm_multi_ABD/CMakeLists.txt: Build configuration for WMMA examples

int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}

Copilot AI Jan 20, 2026

Empty conditional block serves no purpose. Remove this branch or add a comment explaining why argc == 1 is explicitly handled (e.g., "use default parameters").

Suggested change
if(argc == 1) {}
if(argc == 1)
{
// use default parameters when no extra arguments are provided
}

auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

Copilot AI Jan 20, 2026

The reference GEMM arguments use variables a_m_k, b_k_n, and c_m_n that are no longer defined in this scope after the refactoring. These variables were computed within the removed reference computation code and need to be generated by the new ReferenceGemmMultiABD class.

Comment on lines +20 to +21
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const

Copilot AI Jan 20, 2026

Add documentation explaining why this duplicate function definition exists and how it differs from the CK version. The comment on line 20-21 is insufficient - it should explain the const-correctness difference and the implications for usage in profile_gemm_multi_impl.

Suggested change
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
// NOTE:
// This helper intentionally duplicates `concat_tuple_of_refs` from the core CK utilities,
// but with a different const-correctness contract on its arguments:
//
// - The CK version is defined to operate on (typically) const-qualified tuples of
// references; its parameters are more permissive and can accept `const Tuple<...>&`.
// - This host-side overload is deliberately restricted to *non-const* tuples of
// references: `ck::Tuple<X&...>&` and `ck::Tuple<Y&...>&`.
//
// In `profile_gemm_multi_impl`, we need to concatenate tuples that contain non-const
// references to tensors/buffers so that:
// * The resulting concatenated tuple preserves non-const reference semantics, allowing
// the profiled kernels and host-side utilities to modify the referenced objects, and
// * Overload resolution / SFINAE continues to select APIs that require non-const
// references (these would reject a const-qualified tuple produced by the CK version).
//
// If this function were replaced by the CK version, the arguments in
// `profile_gemm_multi_impl` could become (or be treated as) const, which would either:
// - Prevent intended mutation of the underlying tensors, or
// - Cause subtle compilation or behavior differences due to const propagation.
//
// For that reason, this duplicate, non-const overload must remain local to the host-side
// GEMM multi reference implementation and should not be "simplified" by switching to the
// CK variant without carefully revisiting `profile_gemm_multi_impl` and its call sites.

{
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");

Copilot AI Jan 20, 2026

Corrected spelling of 'nullpr' to 'nullptr'.

Suggested change
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");

{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");

Copilot AI Jan 20, 2026

Corrected spelling of 'n0' to 'no'.

Suggested change
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");

{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");

Copilot AI Jan 20, 2026

Corrected spelling of 'n0' to 'no'.

Suggested change
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");

Contributor

@ErwinTerpstra ErwinTerpstra left a comment

Nice work! Seems like it was quite the puzzle to get working. I have placed some comments; some of them are more like suggested improvements. Feel free to put those in a follow-up issue if we want to merge this soon.

// instruction that supports bf16 and we cannot use splitk because of that
if constexpr(std::is_same<AsDataType, ck::bhalf_t>::value)
{
supported = supported & (arg.k_batch_ == 1);
Contributor

Shouldn't this limitation apply only on gfx11? I think gfx12 has the atomic adds for bf16.
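
If so, a minimal sketch of the narrowed check could look like the following (hedged: it assumes a runtime helper such as ck::is_gfx11_supported() from ck/host_utility/device_prop.hpp is usable here, and mirrors the variable names from the quoted code):

if constexpr(std::is_same<AsDataType, ck::bhalf_t>::value)
{
    // gfx11 WMMA lacks bf16 atomic adds, so split-K accumulation cannot be used there;
    // gfx12 is assumed to have them and keeps split-K enabled.
    if(ck::is_gfx11_supported())
    {
        supported = supported & (arg.k_batch_ == 1);
    }
}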

typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
Contributor

I think this parameter is not necessary anymore for the WMMA pipelines


for(index_t i = 0; i < arg.group_count_; i++)
{
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) != true)
Contributor

When using SplitK, HasMainKBlockLoop should be calculated on the per-batch K value (like the CalculateHasMainKBlockLoop() in this class does).
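
A hedged sketch of the per-batch variant (illustration only; the rounding of K / k_batch_ and how a failed check is reported should follow whatever the rest of this class already does):

for(index_t i = 0; i < arg.group_count_; i++)
{
    // with split-K, each k-batch only iterates over K / k_batch elements of the reduction dim
    const index_t k_per_batch = arg.gemm_desc_kernel_arg_[i].K / arg.k_batch_;
    if(!GridwiseGemm::CalculateHasMainKBlockLoop(k_per_batch))
    {
        return false; // or however the surrounding check reports an unsupported argument
    }
}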

barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
}

void UpdateKBatch(index_t) {}
Contributor

Is it correct that this method is empty? If KBatch is not supported, it should maybe be removed or at least raise a runtime error when used.
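
If split-K stays unsupported here, a small sketch of the louder failure mode suggested above (hedged; the parameter name k_batch is assumed, and std::runtime_error matches what the file already throws elsewhere):

void UpdateKBatch(index_t k_batch)
{
    // this WMMA fixed-NK path does not implement split-K; reject any attempt to change it
    // instead of silently ignoring the request
    if(k_batch != 1)
    {
        throw std::runtime_error("UpdateKBatch: split-K is not supported by this device op");
    }
}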

const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
Contributor

Shouldn't the gfx9 be removed here?
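
For illustration, a sketch of the guard with gfx9 dropped (assuming this kernel entry point is only ever compiled for the WMMA-capable targets):

#if defined(__gfx11__) || defined(__gfx12__)
    // kernel body for the WMMA-capable targets goes here
#else
    // other targets (including gfx9, which has its own XDL device op) compile an empty stub
#endif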

for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
Contributor

Not sure if this was already in the XDL example, but it seems incorrect to allocate memory according to the "sum_of_m" size. That would mean every group gets allocated the size of all groups combined. It seems like it should use problem_size.Ms[i] here.
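
A sketch of the per-group sizing suggested above (hedged: it assumes problem_size.Ms[i] holds the actual M of group i and that the buffers do not rely on the sum-of-m layout elsewhere in the example):

for(int i = 0; i < group_count; i++)
{
    // allocate the A0 buffer for this group only: Ms[i] rows by Ks[i] columns
    a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
        sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
}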

bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
Contributor

It might be cleaner to also use your new ReferenceGemmMultipleABD implementation here.

Contributor

As I've placed some comments on the fp16 example, it would be good to place the common code for these examples in a shared file (similar to how other operations usually have a run_grouped_gemm_xxxx_example.inc)

for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
Contributor

This is probably from the XDL implementation, but I think the a_mtx_mraw_kraw_ descriptor will contain the sum of Ms, not the actual M dimension. Then there would be a weird edge case where the sum of M aligns with the vector load size, but the individual Ms don't.

(But as you don't know the actual M, you can't verify it anyway, so maybe that check can be left out)

void* gemm_kernel_host_args_;
index_t grid_size_;
index_t grid_size_grp_;
index_t barrier_size_grp_;
Contributor

Is this unused?
